from typing import Iterable, Tuple

from jpype import JArray
from pipelines.prompta.utils import pta_dfa2compact_dfa, show_dfa, tuple2word, word2tuple
from .base_learner import BaseLearner
from prompta.utils.java_libs import BlueFringeEDSMDFA, Arrays, Collectors, Word
from pipelines.prompta.oracle.base_oracle import BaseOracle


class RPNI_EDSM(BaseLearner):
    """DFA learner wrapping the Java ``BlueFringeEDSMDFA`` learner (via JPype).

    Samples are accumulated inside the Java learner. A hypothesis is then
    computed either purely from those samples (``learn(active=False)``) or by
    repeatedly checking the hypothesis with the oracle and feeding each
    counterexample back in as a new sample (``learn(active=True)``).
    """

    ID = 'RPNI_EDSM'

    def __init__(self, oracle: BaseOracle, exp_dir: str) -> None:
        super().__init__(oracle, exp_dir)
        # Java alphabet exposed by the oracle; every learner created in
        # reset() is built over it.
        self.alphabet = self.oracle.jalphabet
        self.reset()

    def reset(self) -> None:
        """Drop all samples and recorded counterexamples by starting a fresh learner."""
        self.learner = BlueFringeEDSMDFA(self.alphabet)
        self.learner.setParallel(False)
        self.learner.setDeterministic(True)
        # word tuple (via word2tuple) -> oracle output, for every
        # counterexample recorded by _learn_with_oracle since the last reset.
        self.cache = {}

    def learn(self, active: bool = False, max_counterexamples: int = 2000):
        """Compute, store in ``self.hypothesis`` and return a compacted hypothesis.

        Args:
            active: If true, refine the hypothesis with oracle counterexamples
                (see ``_learn_with_oracle``); otherwise learn only from the
                samples added so far.
            max_counterexamples: Counterexample budget for active learning;
                ignored when ``active`` is false.
        """
        if active:
            return self._learn_with_oracle(max_counterexamples)
        return self._learn()

    def _learn(self):
        """Learn from the accumulated samples only (no oracle queries)."""
        raw_model = self.learner.computeModel()
        self.hypothesis = pta_dfa2compact_dfa(raw_model)
        return self.hypothesis

    def add_examples(self, examples: Iterable[Tuple[str, ...]], positive: bool) -> None:
        """Add a batch of samples that all carry the same label.

        Args:
            examples: Words given as tuples of symbols; each one is converted
                with ``tuple2word`` before being handed to the Java learner.
            positive: Add the words as positive samples if true, otherwise as
                negative samples.
        """
        words = [tuple2word(word) for word in examples]
        if not words:
            # Nothing to add; avoid a zero-argument varargs call into Java.
            return
        words = JArray(Word)(words)
        if positive:
            self.learner.addPositiveSamples(*words)
        else:
            self.learner.addNegativeSamples(*words)

    def add_example(self, example) -> None:
        """Add a single sample through the Java learner's ``addSample``.

        ``example`` is passed through unchanged. It is presumably a labelled
        query object like the counterexamples used in ``_learn_with_oracle``;
        confirm this against the callers.
        """
        self.learner.addSample(example)

    def _learn_with_oracle(self, max_counterexamples: int = 2000):
        """Refine the hypothesis with oracle counterexamples.

        Each round computes a model, asks the oracle for a counterexample and
        adds that counterexample as a new sample. The loop stops when:

        * the oracle reports no counterexample;
        * more than ``max_counterexamples`` distinct counterexample words
          have been recorded since the last ``reset``; or
        * the oracle reports a word that is already recorded.

        The last check prevents an endless loop. Re-reporting a known word
        does not grow ``self.cache``, so the size budget alone would never be
        reached in that situation.

        Returns:
            The compacted form of the last computed hypothesis, which is also
            stored in ``self.hypothesis``.
        """
        import warnings

        print('learn with oracle')
        while True:
            hypothesis = self.learner.computeModel()
            # NOTE(review): the second argument appears to select the type of
            # the returned counterexample; confirm against BaseOracle.
            ce = self.oracle.check_conjecture(hypothesis, 'DefaultQuery')
            if ce is None or len(self.cache) > max_counterexamples:
                break
            key = word2tuple(ce.getInput())
            if key in self.cache:
                # Feeding the same word back cannot make progress, and it may
                # contradict the label recorded earlier.
                warnings.warn(
                    f"{self.ID}: oracle repeated counterexample {key!r}; "
                    "stopping refinement with the current hypothesis.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                break
            self.cache[key] = ce.getOutput()
            self.learner.addSample(ce)
        self.hypothesis = pta_dfa2compact_dfa(hypothesis)
        return self.hypothesis

